# Task name. The workload itself is the keras-io `transformer_asr.py` example
# launched by the `run` section below.
name: ljspeech-asr

# Accelerator options for this task: one candidate entry per GPU type
# (T4 and V100).
resources:
  candidates:
    - accelerators: T4
    - accelerators: V100

# Working directory for the task. `setup` clones keras-io and then applies
# `../callback.patch` from inside the clone, so callback.patch is presumably
# shipped in this directory -- TODO confirm.
workdir: ./examples/benchmark/keras_asr

# Prepares the `keras` conda environment, installs SkyCallback and the Python
# dependencies, and checks out a pinned keras-io commit with callback.patch
# applied. Every step is guarded so the script can be re-run on a machine where
# it already completed without failing on `git clone` or `git apply`.
setup: |
  # Reuse the conda env when it already exists; create it only on first run.
  conda activate keras 2>/dev/null || {
    conda create -n keras python=3.8 -y
    conda activate keras
  }

  # Install SkyCallback
  pip install "git+https://github.com/skypilot-org/skypilot.git#egg=sky-callback&subdirectory=sky/callbacks/"

  # User setup
  pip install numpy pandas tensorflow

  # Clone only once: `git clone` errors out if the target directory exists.
  if [ ! -d keras-io/.git ]; then
    git clone https://github.com/keras-team/keras-io.git
  fi
  cd keras-io
  # Pin the example code to a fixed commit so the patch below applies cleanly.
  git checkout 49a16474cc5bbf86792bb7557a70d13fdb7a9c97

  # Apply the patch to enable SkyCallback.
  # `--reverse --check` succeeds only when the patch is already applied, in
  # which case re-applying it would fail; skip it then.
  if git apply --reverse --check ../callback.patch 2>/dev/null; then
    echo "callback.patch already applied; skipping."
  else
    git apply ../callback.patch
  fi

# Runs the keras-io `transformer_asr.py` example inside the `keras` conda
# environment, using the keras-io checkout prepared by `setup`. The script is
# started from its own directory (examples/audio/).
run: |
  conda activate keras
  cd keras-io/examples/audio/
  python transformer_asr.py
